---
title: Categorical DQN
keywords: fastai
sidebar: home_sidebar
summary: "An implementation of a DQN that uses distributions to represent Q, from the paper A Distributional Perspective on Reinforcement Learning"
description: "An implementation of a DQN that uses distributions to represent Q, from the paper A Distributional Perspective on Reinforcement Learning"
nb_path: "nbs/10e_agents.dqn.categorical.ipynb"
---
The Categorical DQN can be summarized as:
Instead of each action output being a single Q value, each is a distribution of size `N`.
We start off with the idea of atoms and supports. A support acts as a mask over the output action distributions. This is illustrated by the equations and the corresponding functions.
We start with the equation...
$$ {\large Z_{\theta}(z,a) = z_i \quad w.p. \: p_i(x,a):= \frac{ e^{\theta_i(x,a)}} {\sum_j{e^{\theta_j(x,a)}}} } $$... which shows that the output of our neural net needs to be squashed into a proper probability distribution. It also defines $z_i$, a support, which we will define very soon. Below is the implementation of the right-hand side of the equation for $p_i(x,a)$.
An important note is that $\frac{ e^{\theta_i(x,a)}} {\sum_j{e^{\theta_j(x,a)}}} $ is just:
Softmax
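In other words, for a single action's logits $\theta(x,a)$ the right-hand side is exactly what a softmax computes. A quick sanity check with hypothetical random logits:

```python
import torch

theta = torch.randn(51)  # hypothetical logits theta_i(x, a) for one action, 51 atoms
p_manual = theta.exp() / theta.exp().sum()  # e^theta_i / sum_j e^theta_j
p_softmax = torch.softmax(theta, dim=0)
# the two agree, and the result is a proper probability distribution (sums to 1)
```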
We pretend that the output of the neural net is of shape (batch_sz, n_atoms, n_actions). In this instance,
there is only one action. This implies that $Z_{\theta}$ is just $z_0$.
out=Softmax(dim=1)(torch.randn(1,51,1))[0] # Action 0
plt.plot(out.numpy())
The next function describes how probabilities are calculated from the neural net output. The equation describes a $z_i$ which is explained by: $$ \{z_i = V_{min} + i\Delta z : 0 \leq i < N \}, \: \Delta z := \frac{V_{max} - V_{min}}{N - 1} $$
Where $V_{max}$, $V_{min}$, and $N$ are constants that we define. Note that $N$ is the number of atoms. So what does a $z_i$ look like? We will define this in code below...
import matplotlib.pyplot as plt
support_dist,z_delta=create_support()
print('z_delta: ',z_delta)
plt.plot(support_dist.numpy())
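For reference, the fastrl `create_support` used above could be sketched directly from the equation. The defaults here ($V_{min}=-10$, $V_{max}=10$, $N=51$, the common C51 choices) are assumptions, not necessarily fastrl's actual defaults:

```python
import torch

def create_support_sketch(v_min=-10.0, v_max=10.0, n_atoms=51):
    # z_i = V_min + i * delta_z for 0 <= i < N, with delta_z = (V_max - V_min) / (N - 1)
    z_delta = (v_max - v_min) / (n_atoms - 1)
    support = v_min + torch.arange(n_atoms, dtype=torch.float32) * z_delta
    return support, z_delta

support_dist, z_delta = create_support_sketch()
# support_dist runs linearly from V_min to V_max across N atoms
```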
This is a single $z_i$ in $Z_{\theta}$. The number of $z_i$s is equal to the number of actions that the DQN is operating with.
{% include note.html content='Josiah: Is this always the case? Could there be only $z_0$ and multiple actions?' %}
Ok! Hopefully this wasn't too bad to get through. We basically normalized the neural net output to be nicer to deal with,
and created/initialized a bunch of increasing arrays that we are calling discrete distributions, i.e. the output of create_support.
Now for the fun part! We have this giant ass update equation:
$$ {\large (\Phi\hat{\mathcal{T}}Z_{\theta}(x,a))_i = \sum_{j=0}^{N-1} \left[ 1 - \frac{ \left| [\hat{\mathcal{T}}z_j]_{V_{min}}^{V_{max}} - z_i \right| }{ \Delta z } \right]_0^1 p_j(x^{\prime},\pi(x^{\prime})) } $$Good god... and we also have
$$ \hat{\mathcal{T}}z_j := r + \gamma z_j $$where $r$ is the observed reward and $\gamma$ is the discount factor.
I highly recommend reading page 6 of the paper for a fuller explanation. I was originally wondering what the difference was between $\pi$ and plain $\theta$; the main difference is that $\pi$ is a greedy action selection, i.e. we run argmax to get the action.
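Before tackling the full projection, note that $\hat{\mathcal{T}}$ on its own just shifts the support by the reward and scales it by $\gamma$, while the $[\cdot]_{V_{min}}^{V_{max}}$ clamp keeps the result in range. A small sketch (the reward, $\gamma$, and bounds here are illustrative):

```python
import torch

v_min, v_max, n_atoms = -10.0, 10.0, 51
z_delta = (v_max - v_min) / (n_atoms - 1)
support = v_min + torch.arange(n_atoms, dtype=torch.float32) * z_delta

reward, gamma = 1.0, 0.99
# \hat{T}z_j = r + gamma * z_j, clamped to [V_min, V_max] before projection
target_z = torch.clamp(reward + gamma * support, v_min, v_max)
```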
This was a lot! Luckily the authors give a re-formulation in algorithmic form:
def categorical_update(v_min,v_max,n_atoms,support,delta_z,model,reward,gamma,action,next_state):
    # Q(x', a) = sum_i z_i p_i(x', a) for each action
    q=(support*Softmax(dim=-1)(model(next_state))).sum(dim=-1)
    a_star=torch.argmax(q)                           # greedy next action pi(x')
    p=Softmax(dim=-1)(model(next_state))[a_star]     # p_j(x', a*)
    m=torch.zeros((n_atoms,))                        # m_i = 0 for i in 0,...,N-1
    for j in range(n_atoms):
        # Compute the projection of $ \hat{\mathcal{T}}z_j $ onto the support $ \{z_i\} $
        target_z=torch.clamp(reward+gamma*support[j],v_min,v_max)
        b_j=(target_z-v_min)/delta_z                 # b_j in [0, N-1]
        l=torch.floor(b_j).long()
        u=torch.ceil(b_j).long()
        # Distribute probability of $ \hat{\mathcal{T}}z_j $
        m[l]=m[l]+p[j]*(u-b_j)
        m[u]=m[u]+p[j]*(b_j-l)
    return # Some cross entropy loss between m and p(x, a)
There is a small problem with the above function. This was a (fairly) literal conversion from Algorithm 1 in the paper to Python,
so it inherits the paper's terse single-letter variable names, which make the code hard to read.
Let's rename these! We will instead have:
$$
m\_i \rightarrow projection\\
a\_star \rightarrow next\_action\\
b\_j \rightarrow support\_value\\
l \rightarrow support\_left\\
u \rightarrow support\_right\\
$$
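With those names in place, a self-contained single-sample sketch of the projection loop looks like the following (pure torch; the extra branch handles the edge case where the target lands exactly on an atom, so no probability mass is lost):

```python
import torch

def project_distribution(next_p, reward, gamma, support, v_min, v_max, z_delta):
    # next_p: p_j(x', a*) for one sample, shape [n_atoms]
    n_atoms = support.shape[0]
    projection = torch.zeros(n_atoms)
    for j in range(n_atoms):
        # clamp \hat{T}z_j into [V_min, V_max], then express it in atom units
        target = torch.clamp(reward + gamma * support[j], v_min, v_max)
        support_value = (target - v_min) / z_delta          # in [0, N-1]
        support_left = int(support_value.floor())
        support_right = int(support_value.ceil())
        if support_left == support_right:                   # target sits exactly on an atom
            projection[support_left] += next_p[j]
        else:                                               # split mass between neighbors
            projection[support_left] += next_p[j] * (support_right - support_value)
            projection[support_right] += next_p[j] * (support_value - support_left)
    return projection
```

As a sanity check: with `reward=0` and `gamma=1` every atom projects onto itself, so the output should equal the input distribution.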
So let's revise the problem and pretend that we have a 2-action model with a batch size of 8, where the last element has a reward of 0, and where left actions are -1 while right actions are 1.
from torch.distributions.normal import Normal
So for a single action we would have a distribution like this...
plt.plot(Normal(0,1).sample((51,)).numpy())
So since our model has 2 actions that it can pick, we create some distributions for them...
dist_left=torch.vstack([Normal(0.5,1).sample((1,51)),Normal(0.5,0.1).sample((1,51))]).unsqueeze(0)
dist_right=torch.vstack([Normal(0.5,0.1).sample((1,51)),Normal(0.5,1).sample((1,51))]).unsqueeze(0)
(dist_left.shape,dist_right.shape)
...where the $[1, 2, 51]$ is $[batch, action, n\_atoms]$
model_out=torch.vstack([copy([dist_left,dist_right][i%2==0]) for i in range(1,9)]).to(device=default_device())
(model_out.shape)
summed_model_out=model_out.sum(dim=2);summed_model_out=Softmax(dim=1)(summed_model_out).to(device=default_device())
(summed_model_out.shape,summed_model_out)
So when we sum/normalize the distributions per batch, per action, we get an output that looks like your typical DQN output...
We can also treat this like a regular DQN and do an argmax to get actions like usual...
actions=torch.argmax(summed_model_out,dim=1).reshape(-1,1).to(device=default_device());actions
rewards=actions;rewards
dones=Tensor().new_zeros((8,1)).bool().to(device=default_device());dones[-1][0]=1;dones
So let's decompose the categorical_update above into something easier to read. First we will note the authors' original algorithm:
{% include image.html width="500" height="500" max-width="500" file="/fastrl/docs/images/10e_agents.dqn.categorical_algorithm1.png" %}
We can break this into 3 different functions:
- getting the Q<br>
- calculating the update<br>
- calculating the loss
We will start with $Q(x_{t+1},a):=\sum_i z_i p_i(x_{t+1},a)$
The CategoricalDQN.q function gets us 90% of the way to the equation above. However,
you will notice that the equation is for a specific action. We will handle this in the actual update function.
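To make the "specific action" part concrete: given an output of shape [batch, action, n_atoms], we can compute per-action Q values and then pick out each row's chosen-action distribution. A sketch with random stand-in probabilities (the shapes match the example above; the support bounds are illustrative):

```python
import torch

batch, n_actions, n_atoms = 8, 2, 51
p = torch.softmax(torch.randn(batch, n_actions, n_atoms), dim=2)  # p_i(x, a)
support = torch.linspace(-10, 10, n_atoms)

q = (p * support).sum(dim=2)                # Q(x, a) = sum_i z_i p_i(x, a), shape [8, 2]
actions = q.argmax(dim=1)                   # greedy action per batch element
p_chosen = p[torch.arange(batch), actions]  # chosen-action distributions, shape [8, 51]
```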
dqn=CategoricalDQN(4,2).to(device=default_device())
dqn(torch.randn(8,4).to(device=default_device())).shape
dqn.q(torch.randn(8,4).to(device=default_device()))
dqn.policy(torch.randn(8,4).to(device=default_device()))
output=categorical_update(dqn.supports,dqn.z_delta,summed_model_out,
Softmax(dim=2)(model_out),actions,rewards,dones,passes=None)
show_q_distribution(output)
q=dqn.q(torch.randn(8,4).to(device=default_device()))
p=dqn.p(torch.randn(8,4).to(device=default_device()))
output=categorical_update(dqn.supports,dqn.z_delta,q,p,actions,rewards,dones)
show_q_distribution(output,title='Real Model Update Distributions')
dqn=CategoricalDQN(4,2)
agent=Agent(dqn,cbs=[ArgMaxFeed,DiscreteEpsilonRandomSelect])
source=Source(cbs=[GymLoop('CartPole-v1',agent,steps_count=3,seed=0,
steps_delta=1),FirstLast])
dls=SourceDataBlock().dataloaders([source],n=1000,bs=1,num_workers=0)
learn=Learner(dls,agent,loss_func=PartialCrossEntropy,
cbs=[ExperienceReplay(bs=32,max_sz=100000,warmup_sz=32),CategoricalDQNTrainer(target_sync=300)],
metrics=[Reward,Epsilon])
full=True
learn.fit(47 if full else 3,lr=0.0001,wd=0)
from IPython.display import HTML
import plotly.express as px
learn.cbs[-1].local_pred.shape
learn.cbs[-1].local_v.shape
show_q(learn.cbs[-1].local_xb[0])
show_q(learn.cbs[-1].local_pred)
show_q(learn.cbs[-1].local_v[:,1,:])
show_q(learn.cbs[-1].local_v[:,0,:])
(-learn.cbs[-1].local_pred*learn.cbs[-1].local_xb[0]).sum(dim=1).mean()
show_q(-learn.cbs[-1].local_pred*learn.cbs[-1].local_xb[0])
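The product above is the per-sample cross-entropy between the projected target distribution and the model's predicted log-probabilities. In isolation, with random stand-ins for the target $m$ and the prediction, the computation looks like this:

```python
import torch

batch, n_atoms = 4, 51
m = torch.softmax(torch.randn(batch, n_atoms), dim=1)          # projected target distributions
log_p = torch.log_softmax(torch.randn(batch, n_atoms), dim=1)  # predicted log-probabilities
loss = (-m * log_p).sum(dim=1).mean()  # cross-entropy per sample, averaged over the batch
```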
agent=Agent(dqn,cbs=[CategoricalArgMaxFeed,DiscreteEpsilonRandomSelect(min_epsilon=0.0001,max_epsilon=0.0002,epsilon=0.0002)])
source=Src('CartPole-v1',agent,seed=0,steps_count=1,n_envs=1,steps_delta=1,mode='rgb_array',cbs=[GymSrc,FirstLast])
exp=[o for o,_ in zip(source,range(50))]
fig = px.imshow(torch.vstack([o['image'] for o in exp]).numpy(),animation_frame=0)
HTML(fig.to_html())
show_q_and_max_distribution(dqn.policy(torch.vstack([o['state'] for o in exp]).to(device=default_device())))
If you want to run this using multiple processes, the multiprocessing code looks like the below. However, you will not be able to run this in a notebook; instead, add it to a .py file and run it from there.
{% include warning.html content='There is a bug in data block that prevents this. Should be a simple fix.' %}